import numpy as np
import PIL.Image
import io
import base64
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI
import time


def _starts_with_item_number(line):
    """Return True when *line* begins with a list number such as ``3.`` or ``12.``.

    The number must be made of ASCII digits, must not start with ``0``, and
    must be followed immediately by a period.
    """
    number, dot, _ = line.partition('.')
    return (
        bool(dot)
        and number.isascii()
        and number.isdigit()
        and not number.startswith('0')
    )


def extract_objects_and_descriptions(text):
    """Parse a markdown-style description into parallel object/description lists.

    The text is scanned line by line:

    * ``### <name>`` starts a new object, unless the header contains
      ``"Important Objects"``, in which case it only acts as a group header
      and the objects are expected as ``####`` sub-headers.
    * ``#### <name>`` starts a new (sub-)object.
    * ``<n>. <sentence>`` lines (any positive item number) are collected as
      the description of the current object and joined with single spaces.
    * Blank lines and ``### Background`` header lines are skipped.

    An object is only emitted when it has at least one description line.

    Args:
        text: The raw description. Both real newlines and the two-character
            escaped sequence ``\\n`` are treated as line breaks.

    Returns:
        A tuple ``(objects, descriptions)`` of equally long lists, where
        ``descriptions[i]`` is the joined description of ``objects[i]``.
    """
    # Handle both literal newlines and escaped newlines
    text = text.replace('\\n', '\n')

    objects = []
    descriptions = []

    current_object = None
    current_description = []

    def save_current():
        # Emit the object collected so far, but only if it has a name and
        # at least one description sentence.
        if current_object and current_description:
            objects.append(current_object)
            descriptions.append(' '.join(current_description))

    for raw_line in text.split('\n'):
        line = raw_line.strip()

        # Skip empty lines and background header lines.
        # NOTE(review): the background header is skipped without closing the
        # current object, so numbered items that follow it are appended to
        # the previous object's description - confirm this is intended.
        if not line or line.startswith('### Background'):
            continue

        if line.startswith('### '):
            # Main section header: close the previous object first.
            save_current()
            if "Important Objects" in line:
                # Group header only; the objects follow as '####' headers.
                current_object = None
            else:
                current_object = line.replace('### ', '')
            current_description = []

        elif line.startswith('#### '):
            # Sub-section header: close the previous object, start a new one.
            save_current()
            current_object = line.replace('#### ', '')
            current_description = []

        elif _starts_with_item_number(line):
            # Keep the text after "<n>. "; if there is no ". " separator,
            # keep the whole line.
            if '. ' in line:
                current_description.append(line.split('. ', 1)[1])
            else:
                current_description.append(line)

    # Add the last object and description if they exist
    save_current()

    return objects, descriptions


def gpt_response(object, description, max_attempts=10, retry_delay=5):
    """Ask the chat model to rewrite *description* with PropBank-style labels.

    A single-prompt request is sent to ``gpt-4o-mini`` (temperature 0). The
    call is retried when it raises or when the returned content is empty.

    Args:
        object: Name of the object the description is about (inserted into
            the prompt).
        description: The description text to rewrite and annotate.
        max_attempts: Maximum number of model invocations before giving up.
        retry_delay: Seconds to sleep between consecutive attempts.

    Returns:
        The non-empty ``content`` of the model response.

    Raises:
        RuntimeError: If no non-empty response was obtained after
            ``max_attempts`` attempts. The last error raised by the model
            call, if any, is attached as the exception's cause.
    """
    # Function-scope import: only needed for the API-key lookup below.
    import os

    prompt = f'''
    Given the object and its description, rewrite the description to make sure every word in the sentence gets a PropBank-style annotation, and then assign PropBank-style annotation to each word. Here are the requirements for each sentence: (1). Make sure every sentence only has one predicate and clearly focuses on the predicate, you can use adverbs to replace other verbs.
    (2). Make sure the object is assigned ARG0 or ARG1. 
    (3). Make sure the assigned PropBank-style annotation is selected from the following list [ARG0, ARG1, ARG2, ARGM-LOC (Locative), ARGM-MNR (Manner), ARGM-ADV (Adverbial), ARGM-DIR (Direction), ARGM-PRP (Purpose)]
    (4). Avoid passive voice and copular + present participle construction. Use Explicit Active Structure. 
    (5) Make sure the rewritten sentence is simple and you can remove unnecessary words but do not add things that never appear in the original sentence. 
    For example:
    [ARG0: The man][V: stands][ARGM-LOC: near the sidewalk edge]. [ARG0: The man][V: stands][ARGM-LOC: close to the building wall]
    
    Now, given the object "{object}" and the description "{description}", rewrite the description with EXACTLY the example format:
    '''

    # Prefer a key supplied through the environment so a real credential
    # never has to be committed; fall back to the in-source placeholder.
    api_key = os.environ.get("OPENAI_API_KEY", "xxx")

    model = ChatOpenAI(model="gpt-4o-mini",
                       openai_api_key=api_key,
                       temperature=0,
                       max_tokens=None,
                       timeout=None,
                       max_retries=2)

    messages = HumanMessage(
        content=[
            {"type": "text", "text": prompt},
        ],
    )

    # Remember the most recent failure so the final error can report it.
    last_error = None

    for attempts in range(1, max_attempts + 1):
        try:
            response = model.invoke([messages])
            if response and response.content and len(response.content.strip()) > 0:
                return response.content
            print(f"Empty response received on attempt {attempts}. Retrying...")
        except Exception as e:
            # Best-effort retry: any failure of the call counts as one attempt.
            last_error = e
            print(f"Error on attempt {attempts}: {str(e)}")

        # No point sleeping after the final attempt.
        if attempts < max_attempts:
            print(f"Waiting {retry_delay} seconds before retry...")
            time.sleep(retry_delay)

    raise RuntimeError(
        f"Failed to get response after {max_attempts} attempts"
    ) from last_error


def process_item(object, description):
    """Annotate one (object, description) pair, passing error markers through.

    If either argument equals the string ``"Error"``, a message is printed
    and ``"Error"`` is returned without contacting the model. Otherwise the
    result of ``gpt_response(object, description)`` is returned.
    """
    upstream_failed = object == "Error" or description == "Error"
    if upstream_failed:
        print("Error in processing item")
        return "Error"

    return gpt_response(object, description)


import concurrent.futures
def get_data_responses(object_list, description_list):
    """Run ``process_item`` over paired objects and descriptions in a thread pool.

    One task is submitted per index of ``object_list`` (paired with the
    entry at the same index of ``description_list``) to a pool of 10 worker
    threads. Results are returned in submission order.
    """
    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
        pending = []
        for idx, obj in enumerate(object_list):
            task = executor.submit(process_item, obj, description_list[idx])
            pending.append(task)

        # Collect in submission order so results line up with the inputs.
        response_list = []
        for task in pending:
            response_list.append(task.result())

    return response_list


def process_cur_data(description):
    """Produce PropBank-annotated descriptions for the first two objects of each entry.

    Every entry of ``description`` is parsed with
    ``extract_objects_and_descriptions``; only the first two extracted
    objects (and their descriptions) are used. Entries that yield fewer than
    two objects are replaced by the marker string ``"Error"`` in all four
    slots, so the output stays aligned with the input.

    Args:
        description: A sequence of raw description texts.

    Returns:
        A tuple ``(propbank_obj1_des, propbank_obj2_des)``: two lists, each
        with one response per input entry, for the first and the second
        object respectively.
    """
    obj1_list = []
    obj2_list = []
    des1_list = []
    des2_list = []

    for i, raw_text in enumerate(description):
        obj_list, des_list = extract_objects_and_descriptions(raw_text)
        try:
            obj_1, obj_2 = obj_list[0], obj_list[1]
            des_1, des_2 = des_list[0], des_list[1]
        except IndexError:
            # Fewer than two objects were extracted; keep a placeholder so
            # indices stay aligned with the input.
            print("Error in extracting objects and descriptions for index: ", i)
            obj_1, obj_2 = "Error", "Error"
            des_1, des_2 = "Error", "Error"

        obj1_list.append(obj_1)
        obj2_list.append(obj_2)
        des1_list.append(des_1)
        des2_list.append(des_2)
    print("Extract objects and descriptions done!")
    propbank_obj1_des = get_data_responses(obj1_list, des1_list)
    propbank_obj2_des = get_data_responses(obj2_list, des2_list)
    print(len(propbank_obj1_des), len(propbank_obj2_des))
    print("Get propbank done!")
    return propbank_obj1_des, propbank_obj2_des



save_dir = "data"

# Run the full pipeline once per split: load the raw descriptions, annotate
# them, and save the two lists of annotated descriptions.
# NOTE(review): both splits read '{save_dir}/description.npz' and write
# '{save_dir}/propbank_des.npz', so the "test" pass overwrites the "train"
# output. The file names presumably should depend on `split` - confirm.
for split in ("train", "test"):
    print("Processing: ", split)

    # NOTE: allow_pickle=True executes pickled code on load; only use this
    # with trusted files.
    with np.load(f'{save_dir}/description.npz', allow_pickle=True) as response_list:
        description = response_list['response_list']

    propbank_obj1_des, propbank_obj2_des = process_cur_data(description)

    # save
    np.savez(f'{save_dir}/propbank_des.npz', propbank_obj1_des=propbank_obj1_des,
             propbank_obj2_des=propbank_obj2_des)

